# This script plots the cosine similarity scores for the synthetic articles.
# The scores were computed from pre-trained embeddings and transformation
# matrices; here they are read precomputed from an .rds file.
# It produces Figure 5 in the paper.
library(ggplot2)
library(dplyr)
library(ggthemes)

# Cosine-similarity results, read precomputed from disk
results_list <- readRDS("data/output/cos_sims_synthetic_gpt-3.5-turbo/cos_sims_synthetic_all.rds")

# Stack the results into one data frame and derive the plotting variables:
# - `group` may be stored as a factor, so it is converted via character to get
#   the label's numeric value (not the factor's internal integer code), then
#   shifted by 11 so that 0 marks the break point;
# - `groupbin` is 1 for observations at or after the break point, 0 before it.
cos_simsdf_combined <- bind_rows(results_list) %>%
  mutate(
    group = as.numeric(as.character(group)) - 11,
    groupbin = ifelse(group >= 0, 1, 0)
  )

# Language codes in order of first appearance; used later to order the panels
ordered_countries <- unique(cos_simsdf_combined$country_code)

# Display name -> language code, as stored in the `country_code` column
language_codes <- c(
  English = "en",
  French = "fr",
  Spanish = "es",
  Russian = "ru",
  Mandarin = "zh-CN",
  Arabic = "ar",
  Japanese = "ja",
  Korean = "ko"
)

# Human-readable language name for each row (used for facet labels and colour)
cos_simsdf_combined$language_name <- names(language_codes)[
  match(cos_simsdf_combined$country_code, language_codes)
]

# A code missing from `language_codes` would otherwise silently become an NA
# language name (and an unlabeled panel); surface it so the mapping can be
# extended.
unmapped_codes <- setdiff(
  unique(cos_simsdf_combined$country_code),
  language_codes
)
if (length(unmapped_codes) > 0) {
  warning(
    "No language name for country_code(s): ",
    paste(unmapped_codes, collapse = ", "),
    call. = FALSE
  )
}

# Order the language factor by each code's first appearance in the data, so
# facet panels (and manual colours) follow that order.
ordered_languages <- names(language_codes)[match(ordered_countries, language_codes)]
cos_simsdf_combined$language_name <- factor(
  cos_simsdf_combined$language_name,
  levels = ordered_languages
)

# One colour per language panel. scale_color_manual() assigns unnamed values
# in the order of the language_name factor levels.
colors_for_countries <- c(
  "#E69F00", "#56B4E9", "#009E73", "#F0E442",
  "#0072B2", "#D55E00", "#CC79A7", "#999999"
)

# Figure 5: cosine similarity vs. the recoded time variable, one panel per
# language, with separate linear fits on each side of the break point at
# x = 0 (mapping `group = groupbin` splits each panel's data in two).
fig5 <- ggplot(
  cos_simsdf_combined,
  aes(x = group, y = val, col = language_name, group = groupbin)
) +
  geom_point(alpha = 0.3, size = 2) +
  # Fitted lines only (se = FALSE: no confidence ribbon). `linewidth` replaces
  # the deprecated line `size` since ggplot2 3.4.0.
  geom_smooth(method = "lm", formula = y ~ x, se = FALSE, linewidth = 1.5) +
  theme_tufte(base_family = "Helvetica") +
  geom_vline(xintercept = 0) +
  labs(
    x = "Time var.",
    y = "Cosine similarity, POLITFIG : opposition index",
    color = "Language"
  ) +
  # Zoom the view instead of using ylim(): scale limits discard out-of-range
  # points before geom_smooth() fits, which would silently alter the lines.
  coord_cartesian(ylim = c(-0.2, 0.2)) +
  scale_color_manual(values = colors_for_countries) +
  theme(
    # Legend hidden; each facet strip already shows the language name.
    legend.position = "none",
    axis.text.x = element_text(size = 20),
    axis.text.y = element_text(size = 20),
    axis.title.x = element_text(size = 15),
    axis.title.y = element_text(size = 15),
    legend.text = element_text(size = 15),
    legend.title = element_text(size = 20),
    panel.border = element_rect(colour = "black", fill = NA, linewidth = 1),
    plot.background = element_rect(fill = "white", colour = NA),
    panel.grid.major = element_line(linewidth = 0.1, linetype = "solid"),
    panel.grid.minor = element_line(linewidth = 0.1, linetype = "solid"),
    strip.text = element_text(size = 24)
  ) +
  facet_wrap(~ language_name, ncol = 4)

# Still display the plot, as the original top-level expression did
print(fig5)

# Save the plot object explicitly rather than relying on last_plot(), and make
# sure the output directory exists first.
dir.create("plots", showWarnings = FALSE, recursive = TRUE)
ggsave(
  "plots/fig5.png",
  plot = fig5,
  units = "in", width = 15, height = 5, dpi = 300
)
